{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Convolutional Sentiment Classifier"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "In this notebook, we build a *convolutional* neural net to classify IMDB movie reviews by their sentiment."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Load dependencies"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Using TensorFlow backend.\n"
     ]
    }
   ],
   "source": [
    "import keras\n",
    "from keras.datasets import imdb\n",
    "from keras.preprocessing.sequence import pad_sequences\n",
    "from keras.models import Sequential\n",
    "from keras.layers import Dense, Dropout, Embedding\n",
    "from keras.layers import SpatialDropout1D, Conv1D, GlobalMaxPooling1D # new! \n",
    "from keras.callbacks import ModelCheckpoint\n",
    "import os\n",
    "from sklearn.metrics import roc_auc_score \n",
    "import matplotlib.pyplot as plt \n",
    "%matplotlib inline"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Set hyperparameters"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# Where the per-epoch weight files are written:\n",
    "output_dir = 'model_output/conv'\n",
    "\n",
    "# Training schedule:\n",
    "epochs = 4\n",
    "batch_size = 128\n",
    "\n",
    "# Vocabulary and embedding:\n",
    "n_unique_words = 5000  # vocabulary cap for imdb.load_data and the Embedding layer\n",
    "n_dim = 64  # size of each word vector\n",
    "max_review_length = 400  # every review is padded/truncated to this many tokens\n",
    "pad_type = 'pre'  # short reviews are padded at the start\n",
    "trunc_type = 'pre'  # long reviews lose tokens from the start\n",
    "drop_embed = 0.2  # SpatialDropout1D rate applied to the embedding output\n",
    "\n",
    "# Convolutional layer:\n",
    "n_conv = 256  # number of filters (kernels)\n",
    "k_conv = 3  # kernel length\n",
    "\n",
    "# Dense layer:\n",
    "n_dense = 256  # hidden units\n",
    "dropout = 0.2  # dropout rate after the dense layer"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Load data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# Load the IMDB reviews as integer sequences, capping the vocabulary at n_unique_words.\n",
    "(x_train, y_train), (x_valid, y_valid) = imdb.load_data(\n",
    "    num_words=n_unique_words,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Preprocess data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# Force every review to exactly max_review_length tokens, using 0 as the filler value.\n",
    "# Both splits share one argument set so their preprocessing cannot drift apart.\n",
    "pad_kwargs = dict(maxlen=max_review_length, padding=pad_type, truncating=trunc_type, value=0)\n",
    "x_train = pad_sequences(x_train, **pad_kwargs)\n",
    "x_valid = pad_sequences(x_valid, **pad_kwargs)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "collapsed": true
   },
   "source": [
    "#### Design neural network architecture"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# Embedding -> spatial dropout -> 1D convolution -> global max pool -> dense -> sigmoid.\n",
    "# Shape comments are taken from the recorded model.summary() output for the current hyperparameters.\n",
    "model = Sequential([\n",
    "    Embedding(n_unique_words, n_dim, input_length=max_review_length),  # (None, 400, 64)\n",
    "    SpatialDropout1D(drop_embed),\n",
    "    Conv1D(n_conv, k_conv, activation='relu'),  # (None, 398, 256)\n",
    "    # Conv1D(n_conv, k_conv, activation='relu'),  # optional second conv layer, left disabled\n",
    "    GlobalMaxPooling1D(),  # keeps the maximum of each filter: (None, 256)\n",
    "    Dense(n_dense, activation='relu'),\n",
    "    Dropout(dropout),\n",
    "    Dense(1, activation='sigmoid'),  # single output in (0, 1) per review\n",
    "])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "_________________________________________________________________\n",
      "Layer (type)                 Output Shape              Param #   \n",
      "=================================================================\n",
      "embedding_1 (Embedding)      (None, 400, 64)           320000    \n",
      "_________________________________________________________________\n",
      "spatial_dropout1d_1 (Spatial (None, 400, 64)           0         \n",
      "_________________________________________________________________\n",
      "conv1d_1 (Conv1D)            (None, 398, 256)          49408     \n",
      "_________________________________________________________________\n",
      "global_max_pooling1d_1 (Glob (None, 256)               0         \n",
      "_________________________________________________________________\n",
      "dense_1 (Dense)              (None, 256)               65792     \n",
      "_________________________________________________________________\n",
      "dropout_1 (Dropout)          (None, 256)               0         \n",
      "_________________________________________________________________\n",
      "dense_2 (Dense)              (None, 1)                 257       \n",
      "=================================================================\n",
      "Total params: 435,457\n",
      "Trainable params: 435,457\n",
      "Non-trainable params: 0\n",
      "_________________________________________________________________\n"
     ]
    }
   ],
   "source": [
    "# Print layer output shapes and parameter counts as a sanity check on the architecture.\n",
    "model.summary()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Configure model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# Binary cross-entropy pairs with the single sigmoid output; accuracy is reported each epoch.\n",
    "model.compile(\n",
    "    loss='binary_crossentropy',\n",
    "    optimizer='adam',\n",
    "    metrics=['accuracy'],\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# Create the output directory up front; exist_ok=True makes this safe to re-run and avoids\n",
    "# the check-then-create race of testing os.path.exists() first.\n",
    "os.makedirs(output_dir, exist_ok=True)\n",
    "\n",
    "# Checkpoint after every epoch; Keras fills in {epoch:02d} for each file name.\n",
    "modelcheckpoint = ModelCheckpoint(filepath=output_dir+\"/weights.{epoch:02d}.hdf5\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Train!"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train on 25000 samples, validate on 25000 samples\n",
      "Epoch 1/4\n",
      "25000/25000 [==============================] - 45s - loss: 0.4800 - acc: 0.7472 - val_loss: 0.2931 - val_acc: 0.8754\n",
      "Epoch 2/4\n",
      "25000/25000 [==============================] - 3s - loss: 0.2502 - acc: 0.8974 - val_loss: 0.2602 - val_acc: 0.8918\n",
      "Epoch 3/4\n",
      "25000/25000 [==============================] - 3s - loss: 0.1719 - acc: 0.9357 - val_loss: 0.2811 - val_acc: 0.8848\n",
      "Epoch 4/4\n",
      "25000/25000 [==============================] - 3s - loss: 0.1147 - acc: 0.9609 - val_loss: 0.3094 - val_acc: 0.8834\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "<keras.callbacks.History at 0x7f99e7711f60>"
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Recorded run: validation accuracy peaks in epoch 2 (val_acc 0.8918) and val_loss rises afterwards.\n",
    "# Original note: adding a second convolutional layer was essentially the same, at 89.0%.\n",
    "model.fit(\n",
    "    x_train,\n",
    "    y_train,\n",
    "    batch_size=batch_size,\n",
    "    epochs=epochs,\n",
    "    verbose=1,\n",
    "    validation_data=(x_valid, y_valid),\n",
    "    callbacks=[modelcheckpoint],  # writes one checkpoint file per epoch\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "collapsed": true
   },
   "source": [
    "#### Evaluate"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# Reload the checkpoint saved after the second epoch, the best validation accuracy in the recorded run.\n",
    "# The file is numbered 01 because this Keras version formats {epoch} zero-indexed.\n",
    "# NOTE(review): later Keras releases appear to number checkpoints from 1 -- confirm before reusing.\n",
    "model.load_weights(output_dir + \"/weights.01.hdf5\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "24544/25000 [============================>.] - ETA: 0s"
     ]
    }
   ],
   "source": [
    "# Sigmoid outputs for every validation review (one value in [0, 1] per row).\n",
    "# predict() replaces predict_proba(), which newer Keras/TensorFlow releases deprecate and remove;\n",
    "# for this model it returns the same values. verbose=1 keeps the progress bar predict_proba() printed.\n",
    "y_hat = model.predict(x_valid, verbose=1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAYAAAAD8CAYAAAB+UHOxAAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAAD4FJREFUeJzt3X+sX3ddx/Hni5WB/NyghWBbvSMUpJAYlpsxJEGkZD/J\nuj82UyJSSGMTnIhI1KF/1ABLhr+GJPyw0mkhSDcncQ2bLnM/gho36BhOtrmsbnOrm+xCu6IugIW3\nf3w/m3fjtvfc9t7vd3ef5yNp7jmf8znf83n33t7X93zO+Z6mqpAk9ecZkx6AJGkyDABJ6pQBIEmd\nMgAkqVMGgCR1ygCQpE4ZAJLUKQNAkjplAEhSp1ZMegBHsnLlypqampr0MKQf9Z27Rl9f8KrJjkOa\nwy233PKtqlo1X7+ndABMTU2xZ8+eSQ9D+lF/9+bR17feOMlRSHNK8u9D+jkFJEmdMgAkqVMGgCR1\nygCQpE4ZAJLUKQNAkjplAEhSpwwASeqUASBJnXpKfxJYkiZp6sKrJnbs+y4+e8mP4RmAJHXKAJCk\nThkAktQpA0CSOmUASFKnDABJ6pQBIEmdMgAkqVMGgCR1ygCQpE4ZAJLUKQNAkjplAEhSpwwASeqU\nASBJnTIAJKlTBoAkdcoAkKRODQqAJO9PcnuSbyT5QpJnJzkpyc1J7k5yWZLjW99ntfW9bfvUrNf5\nYGu/K8npS1OSJGmIeQMgyWrgV4HpqnotcBywCfgocElVrQMOAFvaLluAA1X1CuCS1o8k69t+rwHO\nAD6Z5LjFLUeSNNTQKaAVwI8lWQE8B3gIeAtwRdu+Ezi3LW9s67TtG5Kkte+qqu9V1b3AXuCUYy9B\nknQ05g2AqvoP4A+A+xn94j8I3AI8UlWHWrd9wOq2vBp4oO17qPV/8ez2OfaRJI3ZkCmgExm9ez8J\n+HHgucCZc3Stx3Y5zLbDtT/5eFuT7EmyZ2ZmZr7hSZKO0ooBfd4K3FtVMwBJvgj8DHBCkhXtXf4a\n4MHWfx+wFtjXpoxeCOyf1f6Y2fs8rqq2A9sBpqenfyQgFmLqwquOZfejdt/FZ0/kuJK0EEOuAdwP\nnJrkOW0ufwNwB3ADcF7rsxm4si3vbuu07ddXVbX2Te0uoZOAdcBXFqcMSdJCzXsGUFU3J7kC+Bpw\nCLiV0Tv0q4BdST7S2na0XXYAn0uyl9E7/03tdW5Pcjmj8DgEXFBVP1jkeiRJAw2ZAqKqtgHbntR8\nD3PcxVNV3wXOP8zrXARctMAxSpKWgJ8ElqROGQCS1CkDQJI6ZQBIUqcMAEnqlAEgSZ0yACSpUwaA\nJHXKAJCkThkAktQpA0CSOmUASFKnDABJ6pQBIEmdMgAkqVMGgCR1ygCQpE4ZAJLUKQNAkjplAEhS\npwwASeqUASBJnTIAJKlTBoAkdcoAkKROGQCS1CkDQJI6ZQBIUqcMAEnqlAEgSZ0yACSpUwaAJHXK\nAJCkThkAktQpA0CSOmUASFKnDABJ6tSgAEhyQpIrkvxrkjuTvCHJi5Jcm+Tu9vXE1jdJPp5kb5Lb\nkpw863U2t/53J9m8VEVJkuY39Azgj4G/raqfAn4auBO4ELiuqtYB17V1gDOBde3PVuBTAEleBGwD\nXg+cAmx7LDQkSeM3bwAkeQHwJmAHQFV9v6oeATYCO1u3ncC5bXkj8NkauQk4IcnLgNOBa6tqf1Ud\nAK4FzljUaiRJgw05A3g5MAP8WZJbk3wmyXOBl1bVQwDt60ta/9XAA7P239faDtcuSZqAIQGwAjgZ\n+FRVvQ74H/5/umcumaOtjtD+xJ2TrUn2JNkzMzMzYHiSpKMxJAD2Afuq6ua2fgWjQPhmm9qhfX14\nVv+1s/ZfAzx4hPYnqKrtVTVdVdOrVq1aSC2SpAWYNwCq6j+BB5K8qjVtAO4AdgOP3cmzGbiyLe8G\n3tnuBjoVONimiK4BTktyYrv4e1prky
RNwIqB/d4LfD7J8cA9wLsZhcflSbYA9wPnt75XA2cBe4FH\nW1+qan+SDwNfbf0+VFX7F6UKSdKCDQqAqvo6MD3Hpg1z9C3ggsO8zqXApQsZoCRpafhJYEnqlAEg\nSZ0yACSpUwaAJHXKAJCkThkAktQpA0CSOmUASFKnDABJ6pQBIEmdMgAkqVMGgCR1ygCQpE4ZAJLU\nKQNAkjplAEhSpwwASeqUASBJnTIAJKlTBoAkdcoAkKROGQCS1CkDQJI6ZQBIUqcMAEnqlAEgSZ0y\nACSpUwaAJHXKAJCkThkAktQpA0CSOmUASFKnDABJ6pQBIEmdMgAkqVMGgCR1ygCQpE4NDoAkxyW5\nNcmX2vpJSW5OcneSy5Ic39qf1db3tu1Ts17jg639riSnL3YxkqThFnIG8D7gzlnrHwUuqap1wAFg\nS2vfAhyoqlcAl7R+JFkPbAJeA5wBfDLJccc2fEnS0RoUAEnWAGcDn2nrAd4CXNG67ATObcsb2zpt\n+4bWfyOwq6q+V1X3AnuBUxajCEnSwg09A/gY8JvAD9v6i4FHqupQW98HrG7Lq4EHANr2g63/4+1z\n7CNJGrN5AyDJ24CHq+qW2c1zdK15th1pn9nH25pkT5I9MzMz8w1PknSUhpwBvBE4J8l9wC5GUz8f\nA05IsqL1WQM82Jb3AWsB2vYXAvtnt8+xz+OqantVTVfV9KpVqxZckCRpmHkDoKo+WFVrqmqK0UXc\n66vqF4AbgPNat83AlW15d1unbb++qqq1b2p3CZ0ErAO+smiVSJIWZMX8XQ7rt4BdST4C3ArsaO07\ngM8l2cvonf8mgKq6PcnlwB3AIeCCqvrBMRxfknQMFhQAVXUjcGNbvoc57uKpqu8C5x9m/4uAixY6\nSEnS4vOTwJLUKQNAkjplAEhSpwwASeqUASBJnTIAJKlTBoAkdcoAkKROGQCS1CkDQJI6ZQBIUqcM\nAEnqlAEgSZ0yACSpUwaAJHXKAJCkThkAktQpA0CSOmUASFKnDABJ6pQBIEmdMgAkqVMGgCR1ygCQ\npE4ZAJLUKQNAkjplAEhSpwwASeqUASBJnTIAJKlTBoAkdcoAkKROGQCS1CkDQJI6ZQBIUqcMAEnq\nlAEgSZ2aNwCSrE1yQ5I7k9ye5H2t/UVJrk1yd/t6YmtPko8n2ZvktiQnz3qtza3/3Uk2L11ZkqT5\nDDkDOAR8oKpeDZwKXJBkPXAhcF1VrQOua+sAZwLr2p+twKdgFBjANuD1wCnAtsdCQ5I0fvMGQFU9\nVFVfa8v/BdwJrAY2Ajtbt53AuW15I/DZGrkJOCHJy4DTgWuran9VHQCuBc5Y1GokSYMt6BpAking\ndcDNwEur6iEYhQTwktZtNfDArN32tbbDtUuSJmBwACR5HvBXwK9V1XeO1HWOtjpC+5OPszXJniR7\nZmZmhg5PkrRAgwIgyTMZ/fL/fFV9sTV/s03t0L4+3Nr3AWtn7b4GePAI7U9QVdurarqqpletWrWQ\nWiRJCzDkLqAAO4A7q+qPZm3aDTx2J89m4MpZ7e9sdwOdChxsU0TXAKclObFd/D2ttUmSJmDFgD5v\nBH4R+JckX29tvw1cDFyeZAtwP3B+23Y1cBawF3gUeDdAVe1P8mHgq63fh6pq/6JUIUlasHkDoKr+\ngbnn7wE2zNG/gAsO81qXApcuZICSpKXhJ4ElqVMGgCR1ygCQpE4NuQgsSRM1deFVkx7C05JnAJLU\nKc8AlsCk3q3cd/HZEzmupOXJMwBJ6pQBIEmdMgAkqVMGgCR1ygCQpE4ZAJLUKQNAkjplAEhSpwwA\nSeqUASBJnTIAJKlTPgtI0mA+lfPpxTMASeqUASBJnXIK6GlkkqfnPopaWn48A5CkThkAktQpp4C0\nKPxf0MbHO3G0WDwDkKROeQagZW1S74Z3vfzbAGzy3biWMc8AJKlTBoAkdcoAkKROGQCS1CkDQJI6\nZQ
BIUqcMAEnqlAEgSZ0yACSpUwaAJHXKAJCkTo09AJKckeSuJHuTXDju40uSRsYaAEmOAz4BnAms\nB96eZP04xyBJGhn3GcApwN6quqeqvg/sAjaOeQySJMYfAKuBB2at72ttkqQxG/f/B5A52uoJHZKt\nwNa2+t9J7jqG460EvnUM+y83vdULE6r5DY8vvW3chwa/z13IR4+p5p8c0mncAbAPWDtrfQ3w4OwO\nVbUd2L4YB0uyp6qmF+O1loPe6gVr7oU1L41xTwF9FViX5KQkxwObgN1jHoMkiTGfAVTVoSS/AlwD\nHAdcWlW3j3MMkqSRsf+fwFV1NXD1mA63KFNJy0hv9YI198Kal0Cqav5ekqSnHR8FIUmdWvYBMN+j\nJZI8K8llbfvNSabGP8rFNaDmX09yR5LbklyXZNAtYU9lQx8hkuS8JJVk2d8xMqTmJD/fvte3J/mL\ncY9xsQ342f6JJDckubX9fJ81iXEuliSXJnk4yTcOsz1JPt7+Pm5LcvKiDqCqlu0fRheS/w14OXA8\n8M/A+if1+WXg0215E3DZpMc9hpp/DnhOW35PDzW3fs8HvgzcBExPetxj+D6vA24FTmzrL5n0uMdQ\n83bgPW15PXDfpMd9jDW/CTgZ+MZhtp8F/A2jz1CdCty8mMdf7mcAQx4tsRHY2ZavADYkmesDacvF\nvDVX1Q1V9WhbvYnR5y2Ws6GPEPkw8HvAd8c5uCUypOZfAj5RVQcAqurhMY9xsQ2puYAXtOUX8qTP\nES03VfVlYP8RumwEPlsjNwEnJHnZYh1/uQfAkEdLPN6nqg4BB4EXj2V0S2Ohj9PYwugdxHI2b81J\nXgesraovjXNgS2jI9/mVwCuT/GOSm5KcMbbRLY0hNf8u8I4k+xjdTfje8QxtYpb08Tljvw10kc37\naImBfZaTwfUkeQcwDfzsko5o6R2x5iTPAC4B3jWuAY3BkO/zCkbTQG9mdJb390leW1WPLPHYlsqQ\nmt8O/HlV/WGSNwCfazX/cOmHNxFL+vtruZ8BzPtoidl9kqxgdNp4pFOup7ohNZPkrcDvAOdU1ffG\nNLalMl/NzwdeC9yY5D5Gc6W7l/mF4KE/21dW1f9W1b3AXYwCYbkaUvMW4HKAqvon4NmMnhP0dDXo\n3/vRWu4BMOTREruBzW35POD6aldXlql5a27TIX/C6Jf/cp8XhnlqrqqDVbWyqqaqaorRdY9zqmrP\nZIa7KIb8bP81owv+JFnJaEronrGOcnENqfl+YANAklczCoCZsY5yvHYD72x3A50KHKyqhxbrxZf1\nFFAd5tESST4E7Kmq3cAORqeJexm98980uREfu4E1/z7wPOAv2/Xu+6vqnIkN+hgNrPlpZWDN1wCn\nJbkD+AHwG1X17cmN+tgMrPkDwJ8meT+jqZB3Lec3dEm+wGgKb2W7rrENeCZAVX2a0XWOs4C9wKPA\nuxf1+Mv4706SdAyW+xSQJOkoGQCS1CkDQJI6ZQBIUqcMAEnqlAEgSZ0yACSpUwaAJHXq/wDB1xm9\njnP+YAAAAABJRU5ErkJggg==\n",
      "text/plain": [
       "<matplotlib.figure.Figure at 0x7f99e74ef2b0>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "# Distribution of the predicted values, with the 0.5 threshold marked in orange.\n",
    "fig, ax = plt.subplots()\n",
    "ax.hist(y_hat)\n",
    "ax.axvline(x=0.5, color='orange');"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'95.97'"
      ]
     },
     "execution_count": 13,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Validation ROC AUC as a percentage string with two decimals.\n",
    "format(roc_auc_score(y_valid, y_hat) * 100.0, '0.2f')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.2"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
